{%- extends 'python/index.py.j2' -%}

{# Content for the `header` block of the extended template: an SPDX copyright and license notice, written as Python comments so it is valid at the top of the generated script. #}
{% block header %}
# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC
# SPDX-License-Identifier: Apache-2.0
{% endblock header %}

{% block body %}
{#-
Splits the notebook's code cells into two groups in a single pass:
module-level imports, which are emitted first, and everything else, which is
emitted as the body of a generated main() function.

---
How it works:
1. `imports_code` and `main_code` collect the output lines of each group.
2. Every line of every code cell is visited once, after the cell source has
   gone through the `ipython2python` filter.
3. A line that starts at column 0 with `import ` or `from ` is an import and
   goes to `imports_code`. Indented imports (inside a def, try, if, ...) are
   deliberately left in place so the enclosing block keeps its body.
4. If such an import opens a parenthesis that it does not close on the same
   line, the `in_multiline_import` flag is set and every following line
   (code or blank) also goes to `imports_code` until a line containing the
   closing ')' is seen. Trailing comments are ignored when looking for
   parentheses.
5. Any other non-blank line is indented by four spaces and goes to
   `main_code`. Blank lines are kept only between lines of code, and runs of
   blank lines are collapsed to one.
6. If no executable statement was collected for main(), `pass` is appended so
   the generated function is still valid Python.
7. Finally, the two lists are printed in the desired structure.

Known limitation: imports continued with a trailing backslash are not
recognised as multi-line.
---
-#}

{%- set imports_code = [] -%}
{%- set main_code = [] -%}
{#- Receives one marker per executable (non-comment) line sent to main(), so
    we can tell afterwards whether the function body would be empty. -#}
{%- set main_statements = [] -%}
{#- A list with one item simulates a mutable variable: a plain `set` inside a
    for loop does not survive the iteration. `[false]` means we are not in a
    multi-line import. -#}
{%- set in_multiline_import = [false] -%}

{%- for cell in nb.cells if cell.cell_type == 'code' -%}
    {%- for line in (cell.source | ipython2python).split('\n') -%}
        {%- set stripped = line.strip() -%}
        {#- Only consulted for import lines. An import statement cannot contain
            a string literal, so there the first hash always starts a comment;
            drop it before looking for parentheses. -#}
        {%- set code = stripped.split('#')[0].rstrip() -%}

        {%- if in_multiline_import[0] -%}
            {#- CASE 1: We are inside a parenthesised multi-line import. -#}
            {%- set _ = imports_code.append(line) -%}
            {#- Imports cannot nest parentheses, so any ')' in the code part
                of the line closes the statement. -#}
            {%- if ')' in code -%}
                {%- set _ = in_multiline_import.pop() -%}
                {%- set _ = in_multiline_import.append(false) -%}
            {%- endif -%}

        {%- elif line.startswith(('import ', 'from ')) -%}
            {#- CASE 2: A new module-level import statement. The unstripped
                line is tested on purpose so indented imports stay in their
                enclosing block. -#}
            {%- set _ = imports_code.append(line) -%}
            {#- It spans several lines when a parenthesis is opened but not
                closed on this line. -#}
            {%- if '(' in code and ')' not in code -%}
                {%- set _ = in_multiline_import.pop() -%}
                {%- set _ = in_multiline_import.append(true) -%}
            {%- endif -%}

        {%- elif stripped -%}
            {#- CASE 3: A regular line of code. Add it to the main function body. -#}
            {%- set _ = main_code.append("    " + line) -%}
            {%- if not stripped.startswith('#') -%}
                {%- set _ = main_statements.append(true) -%}
            {%- endif -%}

        {%- else -%}
            {#- CASE 4: An empty line. Only keep it between lines of code, never
                at the start, and never more than one in a row. -#}
            {%- if main_code and main_code[-1].strip() != "" -%}
                {#- Blank line inside the main code block. -#}
                {%- set _ = main_code.append("") -%}
            {%- elif imports_code and not main_code and imports_code[-1] != "" -%}
                {#- Blank line inside the import block, before main has started. -#}
                {%- set _ = imports_code.append("") -%}
            {%- endif -%}
        {%- endif -%}

    {%- endfor -%}
{%- endfor -%}

{#- A function body needs at least one statement; comments alone do not count. -#}
{%- if not main_statements -%}
    {%- set _ = main_code.append("    pass") -%}
{%- endif -%}

{#- Output the collected imports. Strip any trailing whitespace. #}
{{ (imports_code | join('\n')).rstrip() }}

{#- Add separation between the imports and the main function when there are imports.
    main() always has content at this point, so only the imports need checking. #}
{% if imports_code | join | trim %}

{% endif %}
def main():
{#- Output the collected main code. #}
{{ (main_code | join('\n')).rstrip() }}

if __name__ == "__main__":
    main()
{% endblock body %}
